Enable AITER ASM distributed FA testing in jax/torch #363
base: dev
Conversation
Which functionality not covered by the existing tests does this cover?

Previously, our JAX and PyTorch distributed fused-attn tests only enabled the v2 CK backends, not v3.

Yes, but does it run different fused-attn backend configs/kernels than the non-distributed tests? Or is there a functional concern about their coexistence with RCCL?

In the distributed fused-attn (CP) pytest suite, the reference run is usually a single-GPU fused-attn over the full seqlen (for example sq=skv=8192) using the default attn backend. The target run decomposes the single full-size fused-attn into 4 or 8 smaller fused-attn instances (for example, sq=skv=4096), runs those smaller instances using the default backend, and then "glues" the results in the CP way. In my opinion, the reasons we need to enable v3 for distributed fused-attn are:
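The decompose-and-glue check described above can be sketched in plain NumPy. This is a hedged illustration, not the actual TE/CK kernels or the ring-attention CP schedule (real CP also shards K/V across ranks with communication); here only the query sequence is chunked, each chunk attends to the full K/V, and the concatenated partial outputs must match the single full-seqlen reference run up to float tolerance:

```python
# Hypothetical NumPy sketch of the reference-vs-target comparison in the
# distributed fused-attn test suite. Shapes are scaled down from the
# sq=skv=8192 example in the discussion.
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v):
    # Plain scaled dot-product attention.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    return softmax(scores) @ v

rng = np.random.default_rng(0)
seqlen, dim, chunks = 64, 16, 4  # stand-ins for the full-size test config
q = rng.normal(size=(seqlen, dim))
k = rng.normal(size=(seqlen, dim))
v = rng.normal(size=(seqlen, dim))

# Reference: one full-seqlen attention call.
reference = attention(q, k, v)

# Target: 4 smaller attention calls over query chunks, glued back together.
target = np.concatenate([attention(qc, k, v) for qc in np.split(q, chunks)])

assert np.allclose(reference, target)
```

Because softmax normalizes per query row, chunking queries this way is mathematically exact, which is what lets the suite compare the glued target against the single-GPU reference.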
ci/jax.sh (Outdated)

```diff
 *0.4.35*)
     # Workaround for distributed tests hang with xla_flag
-    XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'
+    XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'
```
This will run it with AOTriton too
Updated with a guard in the JAX ci script
With those changes, the env variables are not seen by the run method; they are applied to the test call only.
Using run_default_fa_lbl. All v3 calls should be labelled with "v3" to distinguish them from the regular test_distributed_fused_attn calls.
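The scoping issue flagged above (a `VAR=1 cmd`-style prefix assignment is seen only by the command it prefixes, not by the enclosing script) can be demonstrated with a small, hypothetical Python sketch; the variable name `NVTE_DEMO_FLAG` is made up for illustration and is not one of the real TE knobs:

```python
# Hedged sketch: a variable injected into a child process's environment is
# visible inside that child but never becomes part of the parent's environment,
# which is why the CI script's `run` helper would not see a prefix assignment
# applied only to the test invocation.
import os
import subprocess

child_env = {**os.environ, "NVTE_DEMO_FLAG": "1"}
out = subprocess.run(
    ["sh", "-c", 'echo "${NVTE_DEMO_FLAG:-unset}"'],
    env=child_env, capture_output=True, text=True,
)
print("child sees:", out.stdout.strip())   # the child sees the value
print("parent sees:", os.environ.get("NVTE_DEMO_FLAG", "unset"))
```

This is why the thread settles on labelling the v3 invocations (via run_default_fa_lbl) rather than relying on the prefix assignments propagating.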
@Micky774 Could you rebase onto the latest dev to incorporate the hot fix for the core sgpu tests?

@ipanfilo Could you check whether all your comments have been addressed?
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: